/**
 * Copyright 2023-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "framework/model/unified_model/model_partitions.h"

#include "infra/base/assertion.h"

#include "framework/infra/log/log.h"
#include "common/math/math_util.h"

using std::string;
namespace hiai {

/**
 * Validate a serialized partition blob and populate modelPartitions_ from it.
 *
 * The partition table (header plus entries) is copied into private memory. Size, count and
 * MODEL_DEF checks and the parse then run against that copy rather than the caller's buffer.
 * NOTE(review): presumably this guards against the caller's memory (e.g. HIDL-shared) changing
 * between check and use — confirm.
 *
 * @param partitionsData   start of the blob: a ModelPartitionTable followed by the partition payloads
 * @param partitionsSize   total size of the blob in bytes
 * @param isPassedFromHidl recorded on every parsed ModelPartition
 * @return SUCCESS, or the failure status produced by the HIAI_EXPECT_* checks on invalid input
 */
Status ModelPartitions::Init(const uint8_t* partitionsData, size_t partitionsSize, bool isPassedFromHidl)
{
    // check original memory partition table, data and size
    HIAI_EXPECT_TRUE((partitionsData != nullptr) && (partitionsSize > sizeof(ModelPartitionTable)));
    // CheckModelPartitionTableValid takes the blob size as uint32_t; reject sizes that would be truncated.
    HIAI_EXPECT_TRUE(static_cast<size_t>(static_cast<uint32_t>(partitionsSize)) == partitionsSize);

    const ModelPartitionTable* partitionTableOrigin = reinterpret_cast<const ModelPartitionTable*>(partitionsData);
    // Bound the entry count before SIZE_OF_MODEL_PARTITION_TABLE derives a byte size from it
    // (assumes the macro grows with num — confirm against its definition).
    HIAI_EXPECT_TRUE(partitionTableOrigin->num <= MAX_PARTITION_NUM);
    uint32_t partitionTableSizeOrigin = SIZE_OF_MODEL_PARTITION_TABLE(*partitionTableOrigin);
    // The whole table must lie inside the input buffer, otherwise the copy below reads out of bounds.
    HIAI_EXPECT_TRUE(partitionTableSizeOrigin <= partitionsSize);

    // allocate and copy partition table to new memory
    std::unique_ptr<uint8_t[]> cpyPartitionTable(new (std::nothrow) uint8_t[partitionTableSizeOrigin]);
    HIAI_EXPECT_NOT_NULL(cpyPartitionTable);
    HIAI_EXPECT_TRUE(memcpy_s(cpyPartitionTable.get(), partitionTableSizeOrigin,
        partitionTableOrigin, partitionTableSizeOrigin) == 0);

    // get copied partition table base address and size (non-null: the buffer was checked above)
    const ModelPartitionTable* partitionTable = reinterpret_cast<const ModelPartitionTable*>(cpyPartitionTable.get());
    uint32_t partitionTableSize = SIZE_OF_MODEL_PARTITION_TABLE(*partitionTable);

    // check the partition table size again with the original memory data size
    HIAI_EXPECT_TRUE(partitionTableSize == partitionTableSizeOrigin);

    // check new memory partition table is valid
    HIAI_EXPECT_TRUE(CheckModelPartitionTableValid(partitionsSize, partitionTable, partitionTableSize));

    // parse model partition data
    HIAI_EXPECT_EXEC(ParseModelPartitionTable(partitionsData, partitionTable, partitionTableSize, isPassedFromHidl));
    return SUCCESS;
}

/**
 * Look up a parsed partition of the given type.
 *
 * Partitions of `type` are counted in storage order, starting at 1. When index == -1 the first
 * partition of `type` is returned. Otherwise the partition whose 1-based ordinal equals `index`
 * is returned. As a consequence, index == 0 (or any other negative value) never matches.
 *
 * @param type  partition type to search for
 * @param index -1 for the first match, otherwise the 1-based ordinal among partitions of `type`
 * @return the matching partition, or an empty Maybe when nothing matches
 */
Maybe<ModelPartition> ModelPartitions::GetPartition(ModelPartitionType type, int8_t index) const
{
    const bool takeFirst = (index == -1);
    int8_t ordinal = 0;
    for (const ModelPartition& candidate : modelPartitions_) {
        if (!(candidate.type == type)) {
            continue;
        }
        ++ordinal;
        if (takeFirst || ordinal == index) {
            return Maybe<ModelPartition>(candidate);
        }
    }
    return Maybe<ModelPartition>(NULL_MAYBE);
}

bool ModelPartitions::CheckModelPartitionTableValid(
    uint32_t partitionsSize, const ModelPartitionTable* mpTable, uint32_t mpTableSize)
{
    if (mpTable->num > MAX_PARTITION_NUM || partitionsSize < mpTableSize) {
        FMK_LOGE("ERROR: params are invalid! mpTable->num is :%d, partitionsSize is %d, mpTableSize is %d",
            mpTable->num, partitionsSize, mpTableSize);
        return false;
    }

    uint32_t dataLength = 0;
    for (uint32_t i = 0; i < mpTable->num; i++) {
        HIAI_EXPECT_EXEC_R(CheckUint32AddOverflow(dataLength, mpTable->partition[i].memSize), false);
        dataLength += mpTable->partition[i].memSize;
    }
    if (partitionsSize != (mpTableSize + dataLength)) {
        FMK_LOGE("invalid partition size");
        return false;
    }

    // check MODEL_DEF partition must exist
    bool isModelDefExist = false;
    for (uint32_t i = 0; i < mpTable->num; i++) {
        const ModelPartitionMemInfo& partition = mpTable->partition[i];
        if (partition.type == ModelPartitionType::MODEL_DEF) {
            isModelDefExist = true;
            break;
        }
    }
    if (!isModelDefExist) {
        FMK_LOGE("ModelPartition of type MODEL_DEF is not exist");
        return false;
    }
    return true;
}

/**
 * Split the blob into ModelPartition views and append them to modelPartitions_.
 *
 * Payloads are laid out back to back immediately after the partition table (starting at
 * mpTableSize), in table order. Entries whose type is greater than AIPP_CUSTOM_INFO are
 * skipped, but their bytes are still stepped over. Each ModelPartition::data points into
 * partitionsData; no payload bytes are copied.
 *
 * In AI_SUPPORT_32_BIT_OS builds, a MODEL_DEF payload whose size is not a multiple of 4 is
 * followed by pad bytes that round it up to a multiple of 4. Those pad bytes are subtracted
 * from the WEIGHTS_DATA entry's size.
 *
 * @param partitionsData   start of the blob (table followed by payloads)
 * @param mpTable          validated partition table
 * @param mpTableSize      size in bytes of the table header plus its entries
 * @param isPassedFromHidl recorded on every parsed ModelPartition
 * @return SUCCESS, or a failure status if a WEIGHTS_DATA entry is too small to contain the
 *         padding. On failure, entries parsed before the bad one remain in modelPartitions_.
 */
Status ModelPartitions::ParseModelPartitionTable(
    const uint8_t* partitionsData, const ModelPartitionTable* mpTable, uint32_t mpTableSize, bool isPassedFromHidl)
{
    uint32_t memOffset = mpTableSize;
#if defined(AI_SUPPORT_32_BIT_OS)
    // Pad bytes that round the MODEL_DEF size up to a multiple of 4 (0 if none were needed).
    uint32_t modelDefPadding = 0;
#endif
    for (uint32_t i = 0; i < mpTable->num; i++) {
        const ModelPartitionMemInfo& memInfo = mpTable->partition[i];
        ModelPartition partition;
        partition.size = memInfo.memSize;
        partition.type = memInfo.type;
        partition.data = const_cast<uint8_t*>(partitionsData + memOffset);
        partition.isPassedFromHidl = isPassedFromHidl;
#if defined(AI_SUPPORT_32_BIT_OS)
        if (partition.type == ModelPartitionType::MODEL_DEF && partition.size % 4 != 0) {
            modelDefPadding = ((partition.size + 3) / 4) * 4 - partition.size;
            memOffset += modelDefPadding;
        }

        if (partition.type == ModelPartitionType::WEIGHTS_DATA) {
            // The padding is carved out of the weights entry. Reject entries too small to hold it
            // instead of letting the unsigned subtraction wrap around to a huge size.
            HIAI_EXPECT_TRUE(memInfo.memSize >= modelDefPadding);
            partition.size = memInfo.memSize - modelDefPadding;
        }
#endif
        memOffset += partition.size;
        if (partition.type > ModelPartitionType::AIPP_CUSTOM_INFO) {
            continue;
        }
        modelPartitions_.push_back(partition);
    }
    return SUCCESS;
}
} // namespace hiai
